{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "941988f8",
   "metadata": {},
   "source": [
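    "| [03_language_model/03_Bert完形填空.ipynb](https://github.com/shibing624/nlp-tutorial/blob/main/03_language_model/03_Bert完形填空.ipynb)  | Cloze (fill-mask) with a BERT model using transformers  |[Open In Colab](https://colab.research.google.com/github/shibing624/nlp-tutorial/blob/main/03_language_model/03_Bert完形填空.ipynb) |\n",
    "\n",
    "# Cloze (Fill-Mask)\n",
    "\n",
    "A language model can solve cloze (fill-mask) tasks by predicting the word that is missing from a sentence.\n",
    "\n",
    "Pretrained language models from the BERT family, which are trained with a masked-language-modeling objective, are well suited to this task."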
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "!pip install transformers"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e96497b5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at hfl/chinese-macbert-base were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'score': 0.3006528317928314,\n",
      "  'sequence': '明 天 天 气 很 好?',\n",
      "  'token': 3698,\n",
      "  'token_str': '气'},\n",
      " {'score': 0.1056363433599472,\n",
      "  'sequence': '明 天 天 会 很 好?',\n",
      "  'token': 833,\n",
      "  'token_str': '会'},\n",
      " {'score': 0.09691553562879562,\n",
      "  'sequence': '明 天 天 还 很 好?',\n",
      "  'token': 6820,\n",
      "  'token_str': '还'},\n",
      " {'score': 0.08303344994783401,\n",
      "  'sequence': '明 天 天 就 很 好?',\n",
      "  'token': 2218,\n",
      "  'token_str': '就'},\n",
      " {'score': 0.08257968723773956,\n",
      "  'sequence': '明 天 天 都 很 好?',\n",
      "  'token': 6963,\n",
      "  'token_str': '都'}]\n",
      "******************************************\n",
      "[{'score': 0.6035364270210266,\n",
      "  'sequence': '明 天 心 情 很 好?',\n",
      "  'token': 2658,\n",
      "  'token_str': '情'},\n",
      " {'score': 0.2056306004524231,\n",
      "  'sequence': '明 天 心 会 很 好?',\n",
      "  'token': 833,\n",
      "  'token_str': '会'},\n",
      " {'score': 0.0558624342083931,\n",
      "  'sequence': '明 天 心 也 很 好?',\n",
      "  'token': 738,\n",
      "  'token_str': '也'},\n",
      " {'score': 0.026620158925652504,\n",
      "  'sequence': '明 天 心 就 很 好?',\n",
      "  'token': 2218,\n",
      "  'token_str': '就'},\n",
      " {'score': 0.015123367309570312,\n",
      "  'sequence': '明 天 心 态 很 好?',\n",
      "  'token': 2578,\n",
      "  'token_str': '态'}]\n",
      "******************************************\n",
      "[{'score': 0.18465296924114227,\n",
      "  'sequence': '张 亮 在 哪 里 任 啊?',\n",
      "  'token': 1557,\n",
      "  'token_str': '啊'},\n",
      " {'score': 0.17146149277687073,\n",
      "  'sequence': '张 亮 在 哪 里 任 的?',\n",
      "  'token': 4638,\n",
      "  'token_str': '的'},\n",
      " {'score': 0.09111949056386948,\n",
      "  'sequence': '张 亮 在 哪 里 任??',\n",
      "  'token': 136,\n",
      "  'token_str': '?'},\n",
      " {'score': 0.0695236399769783,\n",
      "  'sequence': '张 亮 在 哪 里 任 职?',\n",
      "  'token': 5466,\n",
      "  'token_str': '职'},\n",
      " {'score': 0.0664818212389946,\n",
      "  'sequence': '张 亮 在 哪 里 任 过?',\n",
      "  'token': 6814,\n",
      "  'token_str': '过'}]\n",
      "******************************************\n",
      "[{'score': 0.26780781149864197,\n",
      "  'sequence': '少 先 队 员 也 该 为 老 人 让 座 位 。',\n",
      "  'token': 738,\n",
      "  'token_str': '也'},\n",
      " {'score': 0.25287339091300964,\n",
      "  'sequence': '少 先 队 员 不 该 为 老 人 让 座 位 。',\n",
      "  'token': 679,\n",
      "  'token_str': '不'},\n",
      " {'score': 0.172683447599411,\n",
      "  'sequence': '少 先 队 员 应 该 为 老 人 让 座 位 。',\n",
      "  'token': 2418,\n",
      "  'token_str': '应'},\n",
      " {'score': 0.08332845568656921,\n",
      "  'sequence': '少 先 队 员 都 该 为 老 人 让 座 位 。',\n",
      "  'token': 6963,\n",
      "  'token_str': '都'},\n",
      " {'score': 0.052332185208797455,\n",
      "  'sequence': '少 先 队 员 ， 该 为 老 人 让 座 位 。',\n",
      "  'token': 8024,\n",
      "  'token_str': '，'}]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from pprint import pprint\n",
    "\n",
    "from transformers import pipeline\n",
    "\n",
    "# Tolerate a duplicate OpenMP runtime, which otherwise aborts the process on some installs.\n",
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\n",
    "model_name = \"hfl/chinese-macbert-base\"\n",
    "\n",
    "nlp = pipeline(\"fill-mask\",\n",
    "               model=model_name,\n",
    "               tokenizer=model_name,\n",
    "               device=-1,  # -1 runs on CPU; pass a GPU device id such as 0 to run on GPU\n",
    "               )\n",
    "\n",
    "# The mask placeholder is read from the tokenizer rather than hard-coded.\n",
    "pprint(nlp(f\"明天天{nlp.tokenizer.mask_token}很好?\"))\n",
    "print(\"*\" * 42)\n",
    "pprint(nlp(f\"明天心{nlp.tokenizer.mask_token}很好?\"))\n",
    "print(\"*\" * 42)\n",
    "pprint(nlp(f\"张亮在哪里任{nlp.tokenizer.mask_token}?\"))\n",
    "print(\"*\" * 42)\n",
    "pprint(nlp(f\"少先队员{nlp.tokenizer.mask_token}该为老人让座位。\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f5f5580",
   "metadata": {},
   "source": [
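    "By default, downloaded models are cached in `~/.cache/huggingface/transformers` (recent versions of transformers use `~/.cache/huggingface/hub` instead).\n",
    "\n",
    "Model card for `hfl/chinese-macbert-base`:\n",
    "https://huggingface.co/hfl/chinese-macbert-base?text=%E5%B7%B4%E9%BB%8E%E6%98%AF%5BMASK%5D%E5%9B%BD%E7%9A%84%E9%A6%96%E9%83%BD%E3%80%82\n",
    "\n",
    "The prediction logic can also be written by hand, without `pipeline`:"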
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7fc64b26",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at hfl/chinese-macbert-base were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "明天天会很好.\n",
      "明天天气很好.\n",
      "明天天都很好.\n",
      "明天天还很好.\n",
      "明天天也很好.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForMaskedLM, AutoTokenizer\n",
    "\n",
    "# `model_name` is defined in the pipeline cell above.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "# AutoModelForMaskedLM replaces the deprecated AutoModelWithLMHead for masked language models.\n",
    "model = AutoModelForMaskedLM.from_pretrained(model_name)\n",
    "\n",
    "sequence = f\"明天天{tokenizer.mask_token}很好.\"\n",
    "input_ids = tokenizer.encode(sequence, return_tensors=\"pt\")\n",
    "# Position of the mask token within the single input sequence.\n",
    "mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1]\n",
    "with torch.no_grad():  # inference only, so gradients are not needed\n",
    "    token_logits = model(input_ids).logits\n",
    "mask_token_logits = token_logits[0, mask_token_index, :]\n",
    "# Ids of the five highest-scoring vocabulary entries at the masked position.\n",
    "top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()\n",
    "for token in top_5_tokens:\n",
    "    print(sequence.replace(tokenizer.mask_token, tokenizer.decode([token])))"
   ]
  },
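  {
   "cell_type": "markdown",
   "id": "5a1c9e20",
   "metadata": {},
   "source": [
    "The hand-written loop above prints only the filled-in sentences, whereas the `fill-mask` pipeline also reports a `score` for each candidate. That score is a probability obtained by applying softmax to the logits at the masked position and keeping the top entries.\n",
    "\n",
    "The next cell is a minimal, library-free sketch of that step. The helper names `softmax` and `top_k`, and the toy vocabulary and logits, are made up for illustration and are not taken from the model. With real model output, the same two steps would correspond to `torch.softmax(mask_token_logits, dim=-1)` followed by `torch.topk`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d3b7f41",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def softmax(logits):\n",
    "    # Convert raw logits into probabilities that sum to 1.\n",
    "    largest = max(logits)  # subtract the max for numerical stability\n",
    "    exps = [math.exp(x - largest) for x in logits]\n",
    "    total = sum(exps)\n",
    "    return [e / total for e in exps]\n",
    "\n",
    "\n",
    "def top_k(probs, k):\n",
    "    # Return (index, probability) pairs for the k most probable entries.\n",
    "    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)\n",
    "    return ranked[:k]\n",
    "\n",
    "\n",
    "# Toy vocabulary and logits, invented purely for illustration (not model output).\n",
    "toy_vocab = [\"气\", \"会\", \"还\", \"就\", \"都\"]\n",
    "toy_logits = [2.0, 1.0, 0.9, 0.8, 0.7]\n",
    "\n",
    "toy_probs = softmax(toy_logits)\n",
    "for index, prob in top_k(toy_probs, 3):\n",
    "    print(toy_vocab[index], round(prob, 4))"
   ]
  },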
  {
   "cell_type": "markdown",
   "id": "12e2e09f",
   "metadata": {},
   "source": [
    "End of this section."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}